# ------------------------------------------------------------------------------------
# NeRF-Factory
# Copyright (c) 2022 POSTECH, KAIST, Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------
# Modified from DVGO (https://github.com/sunset1995/DirectVoxGO)
# Copyright (c) 2022 Google LLC. All Rights Reserved.
# ------------------------------------------------------------------------------------

import functools
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.model.dvgo.__global__ import *


def create_grid(type, **kwargs):
    """Instantiate a grid module by its class name.

    Args:
        type: ``"DenseGrid"`` or ``"TensoRFGrid"``. (The parameter name shadows
            the builtin but is kept so keyword callers keep working.)
        **kwargs: forwarded unchanged to the selected grid's constructor.

    Returns:
        The constructed grid module.

    Raises:
        NotImplementedError: if ``type`` is not a supported grid name.
    """
    if type == "DenseGrid":
        return DenseGrid(**kwargs)
    if type == "TensoRFGrid":
        return TensoRFGrid(**kwargs)
    raise NotImplementedError(f"Unknown grid type: {type!r}")


""" Dense 3D grid
"""


class DenseGrid(nn.Module):
    """Dense voxel grid queried with trilinear interpolation.

    The learnable ``grid`` has shape ``[1, channels, *world_size]`` and spans
    the axis-aligned box ``[xyz_min, xyz_max]``.
    """

    def __init__(self, channels, world_size, xyz_min, xyz_max, **kwargs):
        """
        Args:
            channels: number of values stored per voxel.
            world_size: grid resolution along the three spatial axes; stored
                as given.
            xyz_min, xyz_max: bounding-box corners, as torch tensors or numpy
                arrays (numpy inputs are converted); both are moved to CUDA.
            **kwargs: extra keyword arguments are accepted and ignored.
        """
        super(DenseGrid, self).__init__()
        self.channels = channels
        self.world_size = world_size
        if isinstance(xyz_min, np.ndarray):
            xyz_min, xyz_max = torch.from_numpy(xyz_min), torch.from_numpy(xyz_max)
        self.register_buffer("xyz_min", xyz_min.cuda())
        self.register_buffer("xyz_max", xyz_max.cuda())
        self.grid = nn.Parameter(torch.zeros([1, channels, *world_size]).cuda())

    def forward(self, xyz):
        """Interpolate the grid at global coordinates.

        Args:
            xyz: tensor of shape ``[..., 3]`` with global coordinates to query.

        Returns:
            Tensor of shape ``[..., channels]``, or ``[...]`` when
            ``channels == 1``.
        """
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1, 1, 1, -1, 3)
        # Normalize into [-1, 1]. grid_sample indexes a 5D input with
        # coordinates ordered (W, H, D), i.e. reversed w.r.t. the grid's last
        # three dims, hence the flip.
        ind_norm = (
            (xyz.cuda() - self.xyz_min.cuda())
            / (self.xyz_max.cuda() - self.xyz_min.cuda())
        ).flip((-1,)) * 2 - 1
        out = F.grid_sample(self.grid, ind_norm, mode="bilinear", align_corners=True)
        out = out.reshape(self.channels, -1).T.reshape(*shape, self.channels)
        if self.channels == 1:
            out = out.squeeze(-1)
        return out

    def scale_volume_grid(self, new_world_size):
        """Resample the grid to ``new_world_size`` (trilinear, corners aligned)."""
        if self.channels == 0:
            # Nothing to interpolate. Allocate the empty grid next to the
            # current one (``__init__`` places it on CUDA) instead of on CPU.
            self.grid = nn.Parameter(
                torch.zeros(
                    [1, self.channels, *new_world_size],
                    dtype=self.grid.dtype,
                    device=self.grid.device,
                )
            )
        else:
            self.grid = nn.Parameter(
                F.interpolate(
                    self.grid.data,
                    size=tuple(new_world_size),
                    mode="trilinear",
                    align_corners=True,
                )
            )
        # Keep the recorded resolution in sync with the resized grid.
        self.world_size = new_world_size

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        """Add total-variation gradients in-place into ``self.grid.grad``.

        The work is done by the ``total_variation_cuda`` extension; ``wx``,
        ``wy``, ``wz`` and ``dense_mode`` are forwarded to it unchanged.
        NOTE(review): ``self.grid.grad`` presumably must already be allocated
        (non-None) — confirm against the extension.
        """
        from src.model.dvgo.__global__ import total_variation_cuda

        total_variation_cuda.total_variation_add_grad(
            self.grid, self.grid.grad, wx, wy, wz, dense_mode
        )

    def get_dense_grid(self):
        """Return the grid parameter itself, shape ``[1, channels, *world_size]``."""
        return self.grid

    @torch.no_grad()
    def __isub__(self, val):
        """Subtract ``val`` from the grid values in place, without autograd."""
        self.grid.data -= val
        return self

    def extra_repr(self):
        # as_tensor accepts a tensor as well as a plain list/tuple world_size.
        world_size = torch.as_tensor(self.world_size).tolist()
        return f"channels={self.channels}, world_size={world_size}"


""" Vector-Matrix decomposed grid
See TensoRF: Tensorial Radiance Fields (https://arxiv.org/abs/2203.09517)
"""


class TensoRFGrid(nn.Module):
    """Vector-Matrix (VM) decomposed grid, following TensoRF.

    A value at (x, y, z) is built from three plane/line products:
    ``xy_plane * z_vec``, ``xz_plane * y_vec`` and ``yz_plane * x_vec``. With
    ``channels > 1`` the concatenated rank components are projected to
    ``channels`` outputs by ``f_vec``; otherwise they are summed.
    """

    def __init__(self, channels, world_size, xyz_min, xyz_max, config):
        """
        Args:
            channels: number of output values per query point.
            world_size: grid resolution (X, Y, Z); stored as given.
            xyz_min, xyz_max: bounding-box corners (anything ``torch.Tensor``
                accepts).
            config: dict with ``"n_comp"`` (rank R of the xz/yz planes and the
                x/y vectors) and optional ``"n_comp_xy"`` (rank of the xy plane
                and z vector; defaults to R).
        """
        super(TensoRFGrid, self).__init__()
        self.channels = channels
        self.world_size = world_size
        self.config = config
        self.register_buffer("xyz_min", torch.Tensor(xyz_min))
        self.register_buffer("xyz_max", torch.Tensor(xyz_max))
        X, Y, Z = world_size
        R = config["n_comp"]
        Rxy = config.get("n_comp_xy", R)
        self.xy_plane = nn.Parameter(torch.randn([1, Rxy, X, Y]) * 0.1)
        self.xz_plane = nn.Parameter(torch.randn([1, R, X, Z]) * 0.1)
        self.yz_plane = nn.Parameter(torch.randn([1, R, Y, Z]) * 0.1)
        self.x_vec = nn.Parameter(torch.randn([1, R, X, 1]) * 0.1)
        self.y_vec = nn.Parameter(torch.randn([1, R, Y, 1]) * 0.1)
        self.z_vec = nn.Parameter(torch.randn([1, Rxy, Z, 1]) * 0.1)
        if self.channels > 1:
            # Projection from the Rxy + R + R stacked components to channels.
            self.f_vec = nn.Parameter(torch.ones([R + R + Rxy, channels]))
            nn.init.kaiming_uniform_(self.f_vec, a=np.sqrt(5))

    def forward(self, xyz):
        """Evaluate the decomposed grid at global coordinates.

        Args:
            xyz: tensor of shape ``[..., 3]`` with global coordinates to query.

        Returns:
            Tensor of shape ``[..., channels]``, or ``[...]`` when
            ``channels <= 1``.
        """
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1, 1, -1, 3)
        ind_norm = (xyz - self.xyz_min) / (self.xyz_max - self.xyz_min) * 2 - 1
        # Append a constant 0 coordinate (index 3), used to sample the width-1
        # line factors at their single column.
        ind_norm = torch.cat([ind_norm, torch.zeros_like(ind_norm[..., [0]])], dim=-1)
        if self.channels > 1:
            out = compute_tensorf_feat(
                self.xy_plane,
                self.xz_plane,
                self.yz_plane,
                self.x_vec,
                self.y_vec,
                self.z_vec,
                self.f_vec,
                ind_norm,
            )
            out = out.reshape(*shape, self.channels)
        else:
            out = compute_tensorf_val(
                self.xy_plane,
                self.xz_plane,
                self.yz_plane,
                self.x_vec,
                self.y_vec,
                self.z_vec,
                ind_norm,
            )
            out = out.reshape(*shape)
        return out

    def scale_volume_grid(self, new_world_size):
        """Resample every plane/line factor to ``new_world_size``.

        Uses bilinear interpolation with aligned corners. ``f_vec`` does not
        depend on resolution and is left untouched. No-op when
        ``channels == 0``.
        """
        if self.channels == 0:
            return
        X, Y, Z = new_world_size
        target_sizes = {
            "xy_plane": [X, Y],
            "xz_plane": [X, Z],
            "yz_plane": [Y, Z],
            "x_vec": [X, 1],
            "y_vec": [Y, 1],
            "z_vec": [Z, 1],
        }
        for name, size in target_sizes.items():
            resized = F.interpolate(
                getattr(self, name).data, size=size, mode="bilinear", align_corners=True
            )
            # nn.Module.__setattr__ re-registers the new Parameter under `name`.
            setattr(self, name, nn.Parameter(resized))
        # Keep the recorded resolution in sync with the resized factors.
        self.world_size = new_world_size

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        """Accumulate a total-variation loss gradient into the factors' ``.grad``.

        Sums smooth-L1 differences between neighbouring entries along each
        spatial axis of every plane and line factor, weighted by ``wx`` /
        ``wy`` / ``wz`` per axis. Divides by 6 and calls ``backward()``.
        ``dense_mode`` is accepted for interface parity with ``DenseGrid`` and
        is unused here.
        """
        loss = (
            wx
            * F.smooth_l1_loss(
                self.xy_plane[:, :, 1:], self.xy_plane[:, :, :-1], reduction="sum"
            )
            + wy
            * F.smooth_l1_loss(
                self.xy_plane[:, :, :, 1:], self.xy_plane[:, :, :, :-1], reduction="sum"
            )
            + wx
            * F.smooth_l1_loss(
                self.xz_plane[:, :, 1:], self.xz_plane[:, :, :-1], reduction="sum"
            )
            + wz
            * F.smooth_l1_loss(
                self.xz_plane[:, :, :, 1:], self.xz_plane[:, :, :, :-1], reduction="sum"
            )
            + wy
            * F.smooth_l1_loss(
                self.yz_plane[:, :, 1:], self.yz_plane[:, :, :-1], reduction="sum"
            )
            + wz
            * F.smooth_l1_loss(
                self.yz_plane[:, :, :, 1:], self.yz_plane[:, :, :, :-1], reduction="sum"
            )
            + wx
            * F.smooth_l1_loss(
                self.x_vec[:, :, 1:], self.x_vec[:, :, :-1], reduction="sum"
            )
            + wy
            * F.smooth_l1_loss(
                self.y_vec[:, :, 1:], self.y_vec[:, :, :-1], reduction="sum"
            )
            + wz
            * F.smooth_l1_loss(
                self.z_vec[:, :, 1:], self.z_vec[:, :, :-1], reduction="sum"
            )
        )
        loss /= 6
        loss.backward()

    def get_dense_grid(self):
        """Materialize the full grid, shape ``[1, channels_or_1, X, Y, Z]``."""
        if self.channels > 1:
            # Same component order as compute_tensorf_feat: xy*z, xz*y, yz*x.
            feat = torch.cat(
                [
                    torch.einsum(
                        "rxy,rz->rxyz", self.xy_plane[0], self.z_vec[0, :, :, 0]
                    ),
                    torch.einsum(
                        "rxz,ry->rxyz", self.xz_plane[0], self.y_vec[0, :, :, 0]
                    ),
                    torch.einsum(
                        "ryz,rx->rxyz", self.yz_plane[0], self.x_vec[0, :, :, 0]
                    ),
                ]
            )
            grid = torch.einsum("rxyz,rc->cxyz", feat, self.f_vec)[None]
        else:
            grid = (
                torch.einsum("rxy,rz->xyz", self.xy_plane[0], self.z_vec[0, :, :, 0])
                + torch.einsum("rxz,ry->xyz", self.xz_plane[0], self.y_vec[0, :, :, 0])
                + torch.einsum("ryz,rx->xyz", self.yz_plane[0], self.x_vec[0, :, :, 0])
            )
            grid = grid[None, None]
        return grid

    def extra_repr(self):
        # as_tensor accepts a tensor as well as a plain list/tuple world_size.
        world_size = torch.as_tensor(self.world_size).tolist()
        return f'channels={self.channels}, world_size={world_size}, n_comp={self.config["n_comp"]}'


def compute_tensorf_feat(
    xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, f_vec, ind_norm
):
    """Evaluate a multi-channel VM-decomposed field at query points.

    Args:
        xy_plane, xz_plane, yz_plane: plane factors, ``[1, rank, H, W]``.
        x_vec, y_vec, z_vec: line factors, ``[1, rank, L, 1]``.
        f_vec: ``[total_rank, channels]`` projection applied to the stacked
            components (order: xy*z, xz*y, yz*x).
        ind_norm: ``[1, 1, N, 4]`` normalized coordinates ``(x, y, z, 0)``.

    Returns:
        ``[N, channels]`` features.
    """

    def sample(factor, cols):
        # grid_sample gives [1, rank, 1, N]; collapse to [rank, N] then transpose.
        picked = F.grid_sample(
            factor, ind_norm[:, :, :, cols], mode="bilinear", align_corners=True
        )
        return picked.flatten(0, 2).T

    # Each plane is paired with the line along the axis it does not span.
    stacked = torch.cat(
        [
            sample(xy_plane, [1, 0]) * sample(z_vec, [3, 2]),
            sample(xz_plane, [2, 0]) * sample(y_vec, [3, 1]),
            sample(yz_plane, [2, 1]) * sample(x_vec, [3, 0]),
        ],
        dim=-1,
    )
    return torch.mm(stacked, f_vec)


def compute_tensorf_val(xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, ind_norm):
    """Evaluate a single-channel VM-decomposed field at query points.

    Args:
        xy_plane, xz_plane, yz_plane: plane factors, ``[1, rank, H, W]``.
        x_vec, y_vec, z_vec: line factors, ``[1, rank, L, 1]``.
        ind_norm: ``[1, 1, N, 4]`` normalized coordinates ``(x, y, z, 0)``.

    Returns:
        ``[N]`` values: the rank-summed products of each plane with its line.
    """

    def sample(factor, cols):
        # grid_sample gives [1, rank, 1, N]; collapse to [rank, N] then transpose.
        picked = F.grid_sample(
            factor, ind_norm[:, :, :, cols], mode="bilinear", align_corners=True
        )
        return picked.flatten(0, 2).T

    xy_term = (sample(xy_plane, [1, 0]) * sample(z_vec, [3, 2])).sum(-1)
    xz_term = (sample(xz_plane, [2, 0]) * sample(y_vec, [3, 1])).sum(-1)
    yz_term = (sample(yz_plane, [2, 1]) * sample(x_vec, [3, 0])).sum(-1)
    return xy_term + xz_term + yz_term


""" Mask grid
It supports queries that distinguish known free space from unknown space.
"""


class MaskGrid(nn.Module):
    """Boolean voxel mask queried at global coordinates via a CUDA lookup.

    When built from a checkpoint, voxels whose alpha reaches
    ``mask_cache_thres`` are set; the remaining voxels mark known free space
    that callers can skip.
    """

    def __init__(
        self, path=None, mask_cache_thres=None, mask=None, xyz_min=None, xyz_max=None
    ):
        """
        Args:
            path: checkpoint file to derive the mask from. When given,
                ``mask``, ``xyz_min`` and ``xyz_max`` are ignored and taken
                from the checkpoint instead.
            mask_cache_thres: alpha threshold used with ``path``; voxels with
                alpha >= threshold are set in the mask.
            mask: voxel mask tensor (converted to bool), used when ``path`` is
                None.
            xyz_min, xyz_max: bounding-box corner tensors, used when ``path``
                is None.

        Raises:
            ValueError: if ``path`` is None and any of ``mask``, ``xyz_min`` or
                ``xyz_max`` is missing.
        """
        super(MaskGrid, self).__init__()
        self.mask_cache_thres = mask_cache_thres

        if path is not None:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            st = torch.load(path)
            # 3x3x3 max-pool dilates the density so the mask errs on the side
            # of keeping voxels next to dense ones.
            density = F.max_pool3d(
                st["model_state_dict"]["density.grid"],
                kernel_size=3,
                padding=1,
                stride=1,
            )
            # alpha = 1 - exp(-softplus(density + act_shift) * voxel_size_ratio)
            alpha = 1 - torch.exp(
                -F.softplus(density + st["model_state_dict"]["act_shift"])
                * st["model_kwargs"]["voxel_size_ratio"]
            )
            mask = (alpha >= self.mask_cache_thres).squeeze(0).squeeze(0)
            xyz_min = torch.Tensor(st["model_kwargs"]["xyz_min"])
            xyz_max = torch.Tensor(st["model_kwargs"]["xyz_max"])
        else:
            if mask is None or xyz_min is None or xyz_max is None:
                raise ValueError(
                    "MaskGrid needs either `path` or all of `mask`, `xyz_min` and `xyz_max`"
                )
            mask = mask.bool().cuda()
            xyz_min = xyz_min.cuda()
            xyz_max = xyz_max.cuda()

        self.register_buffer("mask", mask)
        # Affine world->voxel-index map: (xyz - xyz_min) * scale sends xyz_min
        # to 0 and xyz_max to mask.shape - 1. Presumably maskcache_lookup
        # applies it as xyz * scale + shift — confirm against the extension.
        xyz_len = xyz_max - xyz_min
        self.register_buffer(
            "xyz2ijk_scale",
            (torch.Tensor(list(mask.shape)).cuda() - 1) / xyz_len.cuda(),
        )
        self.register_buffer("xyz2ijk_shift", -xyz_min.cuda() * self.xyz2ijk_scale)

    @torch.no_grad()
    def forward(self, xyz):
        """Look up the mask at query points (used to skip known free space).

        Args:
            xyz: ``[..., 3]`` coordinates in the global frame.

        Returns:
            Result of ``render_utils_cuda.maskcache_lookup`` reshaped to
            ``xyz.shape[:-1]`` (presumably False for points in known free
            space — confirm against the extension).
        """
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(-1, 3)
        from src.model.dvgo.__global__ import render_utils_cuda

        mask = render_utils_cuda.maskcache_lookup(
            self.mask.cuda(),
            xyz.cuda(),
            self.xyz2ijk_scale.cuda(),
            self.xyz2ijk_shift.cuda(),
        )
        mask = mask.reshape(shape)
        return mask

    def extra_repr(self):
        # The original string was missing braces and printed the literal text.
        return f"mask.shape={list(self.mask.shape)}"
